Added is_causal mask argument to flax.nnx.dot_product_attention #5093
copybara-service[bot] merged 1 commit into google:main from
|
Pushed the updated parameterized tests and removed the unnecessary formatting. The tests now cover self-attention with and without a padding mask as well as cross-attention with and without a padding mask. Happy to adjust anything else. |
Thanks! For parameterized tests, the idea is to write all 3 test cases as a parameterized single one. I do not think we need to parameterize on B, T, S etc. |
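A rough sketch of what a single case-driven check along these lines could look like. It is not the test that was merged in this PR: `reference_attention`, `build_mask`, `run_case`, and the `CASES` table are hypothetical names, the attention is a plain `jax.numpy` reference rather than `flax.nnx.dot_product_attention`, and the causal and padding masks are combined with `jnp.logical_and` as a stand-in for the `combine_masks` helper. B, N, and H are fixed; only the attention type and the presence of a padding mask vary.

```python
import jax
import jax.numpy as jnp

B, N, H = 2, 3, 4  # fixed batch size, heads, head dim (not parameterized)


def reference_attention(q, k, v, mask):
    """Plain softmax attention; mask is boolean, broadcastable to (B, N, T, S)."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("btnh,bsnh->bnts", q, k) * scale
    # Masked-out positions get the most negative finite value before softmax.
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bnts,bsnh->btnh", weights, v)
    return out, weights


def build_mask(T, S, padding_mask, is_causal):
    """Assumed combination rule: logical AND of the causal and padding masks."""
    mask = jnp.ones((1, 1, T, S), dtype=bool)
    if is_causal:
        mask = jnp.logical_and(mask, jnp.tril(jnp.ones((T, S), dtype=bool)))
    if padding_mask is not None:
        mask = jnp.logical_and(mask, padding_mask)
    return mask


# (name, query length T, key/value length S, use a padding mask?)
CASES = [
    ("self_attention", 5, 5, False),
    ("self_attention_padding", 5, 5, True),
    ("cross_attention", 3, 5, False),
    ("cross_attention_padding", 3, 5, True),
]


def run_case(T, S, use_padding):
    """Returns the largest attention weight found on a masked-out position."""
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (B, T, N, H))
    k = jax.random.normal(kk, (B, S, N, H))
    v = jax.random.normal(kv, (B, S, N, H))
    padding_mask = None
    if use_padding:
        # Hypothetical padding pattern: the last key position is padding.
        padding_mask = jnp.ones((B, 1, 1, S), dtype=bool).at[..., -1].set(False)
    mask = build_mask(T, S, padding_mask, is_causal=True)
    out, weights = reference_attention(q, k, v, mask)
    assert out.shape == (B, T, N, H)
    leaked = jnp.where(jnp.broadcast_to(mask, weights.shape), 0.0, weights)
    return float(leaked.max())


results = {name: run_case(T, S, pad) for name, T, S, pad in CASES}
```

In a real test suite the `CASES` table would more likely be fed to a decorator such as `absl.testing.parameterized.named_parameters`; the plain loop here only keeps the sketch dependency-free.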
|
@ibbyml thanks for the updates! Please squash all commits into 1; otherwise CI will fail for num_commits >= 5 |
b06c6fc to 0c33425
|
@ibbyml let's keep only your updates |
Sorry about that. Accidentally pulled everything. New push should be squashed correctly. |
0c33425 to 9bab148
What does this PR do?
- Adds an is_causal arg to flax.nnx.dot_product_attention and dot_product_attention_weights.
- Passes is_causal through to the jax.nn.dot_product_attention fast path when possible.
- Combines is_causal with input masks with the combine_masks helper.
Checklist